#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import ast
import json
import re

from openai import OpenAI

# System prompt used with --specify_output_format. It asks the model to reply
# with a Python-call-like list such as [get_weather(location='Beijing')],
# which parse_specified_format_tool_calls() then parses.
# "{tools}" is a str.format() placeholder for the tool specifications; the
# template contains no other braces, so .format(tools=...) is safe.
SPECIFY_OUTPUT_FORMAT_PROMPT = """You are an AI assistant with the role name "assistant." \
Based on the provided API specifications and conversation history from steps 1 to t, \
generate the API requests that the assistant should call in step t+1. \
The API requests should be output in the format [api_name(key1='value1', key2='value2', ...)], \
replacing api_name with the actual API name, key1, key2, etc., with the actual parameter names, \
and value1, value2, etc., with the actual parameter values. The output should start with a square bracket "[" and end with a square bracket "]".
If there are multiple API requests, separate them with commas, for example: \
[api_name(key1='value1', key2='value2', ...), api_name(key1='value1', key2='value2', ...), ...]. \
Do not include any other explanations, prompts, or API call results in the output.
If the API parameter description does not specify otherwise, the parameter is optional \
(parameters mentioned in the user input need to be included in the output; if not mentioned, they do not need to be included).
If the API parameter description does not specify the required format for the value, use the user's original text for the parameter value. \
If the API requires no parameters, output the API request directly in the format [api_name()], and do not invent any nonexistent parameter names.

API Specifications:
{tools}"""

# System prompt used by default (without --specify_output_format). It asks the
# model for Kimi-K2's native tool-call markup
# (<|tool_calls_section_begin|> ... <|tool_calls_section_end|>), which
# extract_tool_call_info() then parses.
# Like the template above, "{tools}" is a str.format() placeholder.
NOT_SPECIFY_OUTPUT_FORMAT_PROMPT = """Important: Only give the tool call requests, \
do not include any other explanations, prompts, or API call results in the output.
The tool call requests generated by you are wrapped by \
<|tool_calls_section_begin|> and <|tool_calls_section_end|>, with each tool call wrapped by <|tool_call_begin|> and <|tool_call_end|>. \
The tool ID and arguments are separated by <|tool_call_argument_begin|>. The format of the tool ID is functions.func_name:idx, \
from which we can parse the function name.

API Specifications:
{tools}"""


def get_weather(location: str):
    """Return a canned weather report for *location* (case-insensitive).

    This is a demo stub:
    - "beijing" maps to "Sunny".
    - "shanghai" maps to "Cloudy".
    - Every other location maps to "Rainy".
    """
    canned_reports = {"beijing": "Sunny", "shanghai": "Cloudy"}
    return canned_reports.get(location.lower(), "Rainy")


# Registry of locally callable tools, keyed by the function name the model
# emits, so a parsed tool call can be dispatched with tool_map[name](**kwargs).
tool_map = dict(get_weather=get_weather)


# ref: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
def extract_tool_call_info(tool_call_rsp: str):
    """Parse Kimi-K2 native tool-call markup into OpenAI-style tool-call dicts.

    Expected markup (see the Kimi-K2 tool-call guidance referenced above)::

        <|tool_calls_section_begin|>
        <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Beijing"}<|tool_call_end|>
        <|tool_calls_section_end|>

    Args:
        tool_call_rsp: Raw text generated by the model. ``None`` or empty
            text is treated as "no tool calls".

    Returns:
        A list of ``{"id", "type", "function": {"name", "arguments"}}``
        dicts, one per tool call, in order of appearance across all complete
        tool-call sections. ``arguments`` is the raw, unparsed argument
        string. Returns ``[]`` if there is no complete section. For example,
        an output truncated before ``<|tool_calls_section_end|>`` yields
        ``[]`` rather than raising.
    """
    if not tool_call_rsp:
        return []

    section_pattern = (r"<\|tool_calls_section_begin\|>(.*?)"
                       r"<\|tool_calls_section_end\|>")
    func_call_pattern = (
        r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*"
        r"<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*"
        r"<\|tool_call_end\|>")

    tool_calls = []
    for section in re.findall(section_pattern, tool_call_rsp, re.DOTALL):
        for function_id, function_args in re.findall(func_call_pattern,
                                                     section, re.DOTALL):
            # function_id looks like "functions.get_weather:0". Drop the
            # ":idx" suffix, then the leading namespace (if present). This
            # keeps un-prefixed and dotted names from crashing or being cut.
            function_name = function_id.rsplit(":", 1)[0].split(".", 1)[-1]
            tool_calls.append({
                "id": function_id,
                "type": "function",
                "function": {
                    "name": function_name,
                    "arguments": function_args
                }
            })
    return tool_calls


def _specified_tool_call(api_name, kwargs):
    """Wrap one parsed call in the OpenAI-style tool-call dict used here."""
    return {
        "type": "function",
        "function": {
            "name": api_name,
            "arguments": kwargs
        }
    }


def _literal_or_source(node, source):
    """Return the Python literal for *node*, or its source text otherwise."""
    try:
        return ast.literal_eval(node)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Non-literal values (e.g. an unquoted name) are kept verbatim as text.
        return ast.get_source_segment(source, node)


def _parse_specified_with_ast(text):
    """Parse ``[f(k=v, ...), ...]`` as a Python expression; ``None`` if not one."""
    source = text.strip()
    try:
        tree = ast.parse(source, mode="eval")
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        return None
    body = tree.body
    calls = body.elts if isinstance(body, (ast.List, ast.Tuple)) else [body]
    if not all(
            isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
            for call in calls):
        return None
    tool_calls = []
    for call in calls:
        # Positional args and **splats carry no parameter name. They are
        # skipped, like the keyword-only regex fallback skips them.
        kwargs = {
            kw.arg: _literal_or_source(kw.value, source)
            for kw in call.keywords if kw.arg is not None
        }
        tool_calls.append(_specified_tool_call(call.func.id, kwargs))
    return tool_calls


def _parse_specified_with_regex(text):
    """Best-effort scan for ``name(k=v, ...)`` fragments anywhere in free text."""
    call_pattern = re.compile(r'(\w+)\s*\(([^)]*)\)')
    kv_pattern = re.compile(r'(\w+)\s*=\s*([^,]+)')
    tool_calls = []
    for m in call_pattern.finditer(text):
        api_name, kv_body = m.group(1), m.group(2)
        kwargs = {}
        for k, v in kv_pattern.findall(kv_body):
            v = v.strip()
            try:
                kwargs[k] = ast.literal_eval(v)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                kwargs[k] = v
        tool_calls.append(_specified_tool_call(api_name, kwargs))
    return tool_calls


def parse_specified_format_tool_calls(text: str):
    """Parse ``[api_name(key='value', ...), ...]`` output into tool-call dicts.

    The text is first parsed as a Python expression with :mod:`ast`. This
    keeps quoted values that contain commas or parentheses intact (e.g.
    ``location='Washington, D.C.'``). If that fails, for example because the
    model wrapped the list in prose, a regex scan for ``name(k=v, ...)``
    fragments is used instead.

    Args:
        text: Raw model output. ``None`` or empty text yields ``[]``.

    Returns:
        A list of ``{"type": "function", "function": {"name", "arguments"}}``
        dicts. ``arguments`` is a dict mapping parameter names to literal
        values, or to the raw text when a value is not a Python literal.
    """
    if not text:
        return []
    tool_calls = _parse_specified_with_ast(text)
    if tool_calls is None:
        tool_calls = _parse_specified_with_regex(text)
    return tool_calls


def get_tools():
    """Build the OpenAI-style JSON-schema specifications of available tools.

    Only ``get_weather`` is exposed. It takes a single required string
    parameter, ``location``. A fresh list is returned on every call.
    """
    location_schema = {
        "type": "string",
        "description": "Location name",
    }
    weather_spec = {
        "name": "get_weather",
        "description":
        "Get weather information. Call this tool when the user needs to get weather information",
        "parameters": {
            "type": "object",
            "required": ["location"],
            "properties": {
                "location": location_schema
            },
        },
    }
    return [dict(type="function", function=weather_spec)]


def get_tool_call_requests(args, client):
    """Ask the model which tools to call and parse its answer.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``prompt`` and
            ``specify_output_format``.
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.

    Returns:
        ``(tool_calls, messages)``. ``tool_calls`` is the list produced by
        ``parse_specified_format_tool_calls`` (with ``--specify_output_format``)
        or by ``extract_tool_call_info`` (native Kimi-K2 markup).
        ``messages`` is the chat history that was sent.
    """
    tools = get_tools()
    if args.specify_output_format:
        prompt_template = SPECIFY_OUTPUT_FORMAT_PROMPT
    else:
        prompt_template = NOT_SPECIFY_OUTPUT_FORMAT_PROMPT
    # Both templates contain a {tools} placeholder. Format *after* choosing
    # the template: a `A if c else B.format(...)` expression formats only B
    # and would send a literal "{tools}" (no API spec) in the other case.
    # The spec is serialized as JSON, matching its JSON-schema shape.
    system_prompt = prompt_template.format(tools=json.dumps(tools))
    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": args.prompt
        },
    ]

    response = client.chat.completions.create(model=args.model,
                                              messages=messages,
                                              max_tokens=256,
                                              temperature=0.0)

    # The OpenAI client types message.content as optional; treat a missing
    # content as empty text so the parsers below do not fail on None.
    output = response.choices[0].message.content or ""
    if args.specify_output_format:
        tool_calls = parse_specified_format_tool_calls(output)
    else:
        tool_calls = extract_tool_call_info(output)
    print(f"[The original output from Kimi-K2]: {output}\n")
    print(f"[The tool-call requests parsed from the output]: {tool_calls}\n")
    return tool_calls, messages


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Kimi-K2 tool-call demo against an OpenAI-compatible "
        "server.")
    parser.add_argument("--model",
                        type=str,
                        default="moonshotai/Kimi-K2-Instruct")
    parser.add_argument("--prompt",
                        type=str,
                        default="What's the weather like in Shanghai today?")
    parser.add_argument("--specify_output_format",
                        action="store_true",
                        default=False)
    parser.add_argument("--base_url",
                        type=str,
                        default="http://localhost:8000/v1",
                        help="Base URL of the OpenAI-compatible server.")
    parser.add_argument("--api_key",
                        type=str,
                        default="tensorrt_llm",
                        help="API key sent to the server.")

    args = parser.parse_args()

    # Start the TRT-LLM server before running this script.
    client = OpenAI(
        api_key=args.api_key,
        base_url=args.base_url,
    )

    tool_calls, messages = get_tool_call_requests(args, client)

    for tool_call in tool_calls:
        tool_name = tool_call['function']['name']
        # The model may request a tool that does not exist; report and skip
        # it instead of aborting the remaining calls with a KeyError.
        tool_function = tool_map.get(tool_name)
        if tool_function is None:
            print(f"[Skipped tool call]: unknown tool_name={tool_name}\n")
            continue

        tool_arguments = tool_call['function']['arguments']
        if not args.specify_output_format:
            # Native Kimi-K2 markup carries the arguments as a JSON string.
            try:
                tool_arguments = json.loads(tool_arguments)
            except json.JSONDecodeError as err:
                print(f"[Skipped tool call]: tool_name={tool_name}, "
                      f"invalid JSON arguments {tool_arguments!r} ({err})\n")
                continue
        if not isinstance(tool_arguments, dict):
            print(f"[Skipped tool call]: tool_name={tool_name}, "
                  f"arguments are not an object: {tool_arguments!r}\n")
            continue

        try:
            tool_result = tool_function(**tool_arguments)
        except TypeError as err:
            # Typically wrong or missing parameter names generated by the
            # model. Note that this also catches TypeErrors raised inside
            # the tool itself.
            print(f"[Skipped tool call]: tool_name={tool_name}, "
                  f"bad arguments {tool_arguments!r} ({err})\n")
            continue
        print(
            f"[Tool call result]: tool_name={tool_name}, tool_result={tool_result}\n"
        )
